package mill.bsp

import java.io.File
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

import ch.epfl.scala.bsp4j._
import ch.epfl.scala.{bsp4j => bsp}
import mill.api.{BuildProblemReporter, Problem}

import scala.collection.JavaConverters._
import scala.collection.concurrent
import scala.language.implicitConversions

/**
 * Specialized reporter that sends compilation diagnostics
 * for each problem it logs, either as information, warning or
 * error, as well as task finish notifications of type `compile-report`.
 *
 * Each diagnostics notification published for a text document carries every
 * diagnostic collected so far for that document (with `reset = true`), so the
 * client receives the complete current list on every update.
 *
 * @param client              the client to send diagnostics to
 * @param target              the build target whose compilation the
 *                            diagnostics are related to
 * @param taskId              a unique id of the compilation task of the target
 *                            specified by `target`
 * @param compilationOriginId optional origin id the client assigned to
 *                            the compilation request. Needs to be sent
 *                            back as part of the published diagnostics
 *                            as well as the compile report
 */
class BspLoggedReporter(
    client: bsp.BuildClient,
    target: BuildTarget,
    taskId: TaskId,
    compilationOriginId: Option[String]
) extends BuildProblemReporter {
  // Atomic counters and a concurrent map are used, presumably because problems may be
  // logged from several threads. These members stay `var` only to keep the public
  // interface unchanged; they are never reassigned in this class.
  var errors = new AtomicInteger(0)
  var warnings = new AtomicInteger(0)
  var infos = new AtomicInteger(0)

  // Latest publish parameters per text document; each entry holds all diagnostics so far.
  var diagnosticMap: concurrent.Map[TextDocumentIdentifier, bsp.PublishDiagnosticsParams] =
    new ConcurrentHashMap[TextDocumentIdentifier, bsp.PublishDiagnosticsParams]().asScala

  override def logError(problem: Problem): Unit = {
    publish(problem)
    errors.incrementAndGet()
  }

  override def logWarning(problem: Problem): Unit = {
    publish(problem)
    warnings.incrementAndGet()
  }

  override def logInfo(problem: Problem): Unit = {
    publish(problem)
    infos.incrementAndGet()
  }

  // Records the problem and sends the accumulated diagnostics of its text document to the client.
  private[this] def publish(problem: Problem): Unit =
    client.onBuildPublishDiagnostics(getDiagnostics(problem, target.getId, compilationOriginId))

  // Obtains the parameters for publishing diagnostics about the given Problem, together with
  // all previous problems recorded for the same text document, related to `targetId` and
  // carrying the optional client-assigned `originId` of the compile request.
  // If the problem has no associated source file, the document identifier falls back to
  // the URI of the build target.
  private[this] def getDiagnostics(
      problem: Problem,
      targetId: bsp.BuildTargetIdentifier,
      originId: Option[String]
  ): bsp.PublishDiagnosticsParams = {
    val textDocument = new TextDocumentIdentifier(
      problem.position.sourceFile.fold(targetId.getUri)(_.toURI.toString)
    )
    appendDiagnostic(textDocument, targetId, originId, getSingleDiagnostic(problem))
  }

  // Atomically appends `diagnostic` to those already recorded for `textDocument` and returns
  // the stored parameters. The compare-and-set retry (putIfAbsent / replace(old, new)) ensures
  // concurrent calls for the same document cannot overwrite each other's diagnostics.
  // NOTE(review): the order in which concurrent callers then publish is not coordinated here.
  @scala.annotation.tailrec
  private[this] def appendDiagnostic(
      textDocument: TextDocumentIdentifier,
      targetId: bsp.BuildTargetIdentifier,
      originId: Option[String],
      diagnostic: Diagnostic
  ): bsp.PublishDiagnosticsParams =
    diagnosticMap.get(textDocument) match {
      case None =>
        val updated = newParams(textDocument, targetId, originId, List(diagnostic))
        if (diagnosticMap.putIfAbsent(textDocument, updated).isEmpty) updated
        else appendDiagnostic(textDocument, targetId, originId, diagnostic)
      case Some(previous) =>
        val updated = newParams(
          textDocument,
          targetId,
          originId,
          previous.getDiagnostics.asScala.toList :+ diagnostic
        )
        if (diagnosticMap.replace(textDocument, previous, updated)) updated
        else appendDiagnostic(textDocument, targetId, originId, diagnostic)
    }

  // Builds publish parameters for a document; the trailing `true` is the BSP `reset` flag,
  // so the given list replaces whatever the client previously held for that document.
  private[this] def newParams(
      textDocument: TextDocumentIdentifier,
      targetId: bsp.BuildTargetIdentifier,
      originId: Option[String],
      diagnostics: List[Diagnostic]
  ): bsp.PublishDiagnosticsParams = {
    val params =
      new bsp.PublishDiagnosticsParams(textDocument, targetId, diagnostics.asJava, true)
    originId.foreach(id => params.setOriginId(id))
    params
  }

  // Computes the BSP diagnostic related to the given Problem.
  // Zinc's line numbers start at 1 whereas BSP's start at 0, so every line number read from
  // the position is shifted down by one.
  private[this] def getSingleDiagnostic(problem: Problem): Diagnostic = {
    val pos = problem.position
    val line = pos.line.map(_ - 1)
    // NOTE(review): startOffset/endOffset are used as the character (column) position, falling
    // back to `pointer`; if they are absolute file offsets, a column-based field should be used.
    val start = new bsp.Position(
      pos.startLine.map(_ - 1).orElse(line).getOrElse[Int](0),
      pos.startOffset.orElse(pos.pointer).getOrElse[Int](0)
    )
    val end = new bsp.Position(
      pos.endLine.map(_ - 1).orElse(line).getOrElse[Int](start.getLine.intValue()),
      pos.endOffset.orElse(pos.pointer).getOrElse[Int](start.getCharacter.intValue())
    )
    val diagnostic = new bsp.Diagnostic(new bsp.Range(start, end), problem.message)
    diagnostic.setCode(pos.lineContent)
    diagnostic.setSource("compiler from mill")
    diagnostic.setSeverity(
      problem.severity match {
        case mill.api.Info => bsp.DiagnosticSeverity.INFORMATION
        case mill.api.Error => bsp.DiagnosticSeverity.ERROR
        case mill.api.Warn => bsp.DiagnosticSeverity.WARNING
      }
    )
    diagnostic
  }

  // Sends a task-finish notification carrying a `compile-report` with the error and
  // warning counts accumulated by this reporter.
  override def printSummary(): Unit = {
    val taskFinishParams = new TaskFinishParams(taskId, getStatusCode)
    taskFinishParams.setEventTime(System.currentTimeMillis())
    taskFinishParams.setMessage(s"Compiled ${target.getDisplayName}")
    taskFinishParams.setDataKind(TaskDataKind.COMPILE_REPORT)
    val compileReport = new CompileReport(target.getId, errors.get, warnings.get)
    compilationOriginId.foreach(id => compileReport.setOriginId(id))
    taskFinishParams.setData(compileReport)
    client.onBuildTaskFinish(taskFinishParams)
  }

  // The compilation status is ERROR if at least one error was logged, OK otherwise.
  private[this] def getStatusCode: StatusCode =
    if (errors.get > 0) StatusCode.ERROR else StatusCode.OK
}
